The goal of a model is to provide a simple low-dimensional summary of a dataset. In the context of this book we’re going to use models to partition data into patterns and residuals. Strong patterns will hide subtler trends, so we’ll use models to help peel back layers of structure as we explore a dataset.

There are two parts to a model:

  1. define a family of models: express a precise, but generic, pattern that you want to capture.
    • straight line: y = a_1 * x + a_2
    • quadratic curve: y = a_1 * x ^ a_2
  2. generate a fitted model: find the model from the family that is closest to your data. This takes the generic model family and makes it specific.
    • y = 3 * x + 7
    • y = 9 * x ^ 2

It’s important to understand that a fitted model is just the closest model from a family of models. That implies that you have the “best” model (according to some criteria); it doesn’t imply that you have a good model and it certainly doesn’t imply that the model is “true”. George Box puts this well in his famous aphorism:

All models are wrong, but some are useful. The goal of a model is not to uncover truth, but to discover a simple approximation that is still useful.

A simple model

sim1 contains two continuous variables, x and y. Let’s plot them to see how they’re related:

sim1
## # A tibble: 30 x 2
##        x     y
##    <int> <dbl>
##  1     1  4.20
##  2     1  7.51
##  3     1  2.13
##  4     2  8.99
##  5     2 10.2 
##  6     2 11.3 
##  7     3  7.36
##  8     3 10.5 
##  9     3 10.5 
## 10     4 12.4 
## # … with 20 more rows
ggplot(sim1, aes(x, y))+
  geom_point()

You can see a strong pattern in the data. Let’s use a model to capture that pattern and make it explicit.

It’s our job to supply the basic form of the model. In this case, the relationship looks linear, i.e. y = a_0 + a_1 * x.

Let’s start by getting a feel for what models from that family look like by randomly generating a few and overlaying them on the data. For this simple case, we can use geom_abline() which takes a slope and intercept as parameters.

# models are defined by two parameters
# a1 is the intercept, a2 is the slope
# 250 models in total
models <- tibble(
  a1 = runif(250, -20, 40),
  a2 = runif(250, -5, 5)
)

models
## # A tibble: 250 x 2
##        a1    a2
##     <dbl> <dbl>
##  1   9.41 -3.94
##  2  30.4   1.19
##  3 -14.3  -4.53
##  4 -14.9  -2.89
##  5 -18.5  -1.35
##  6  24.1   4.45
##  7  -4.43  3.06
##  8 -11.4  -1.62
##  9  -4.14  2.08
## 10  33.0   1.85
## # … with 240 more rows
ggplot(sim1, aes(x, y))+
  geom_abline(aes(intercept = a1, slope = a2),
              data = models,
              alpha = 1/4)+
  geom_point()

There are 250 models on this plot, but a lot are really bad! We need to find the good models by making precise our intuition that a good model is “close” to the data.

We need a way to quantify the distance between the data and a model. Then we can fit the model by finding the values of a_1 and a_2 that generate the model with the smallest distance from the data.

One easy place to start is to find the vertical distance between each point and the model. This distance is just the difference between the y value given by the model (the prediction), and the actual y value in the data (the response).

To compute this distance:

  1. we first turn our model family into an R function:
  • input
    • model parameters
    • data
  • output
    • values predicted by the model
# turn our model family into an R function
model1 <- function(a, data) {
  a[1] + data$x * a[2]
}

model1(c(7, 1.5), sim1)
##  [1]  8.5  8.5  8.5 10.0 10.0 10.0 11.5 11.5 11.5 13.0 13.0 13.0 14.5 14.5
## [15] 14.5 16.0 16.0 16.0 17.5 17.5 17.5 19.0 19.0 19.0 20.5 20.5 20.5 22.0
## [29] 22.0 22.0
  2. Next, we need some way to compute an overall distance between the predicted and actual values.

In other words, we have 30 distances, one per data point: how do we collapse them into a single number?

One common way to do this in statistics is to use the “root-mean-squared deviation”:

  1. compute the differences between actual and predicted values,
  2. square them,
  3. average them,
  4. take the square root.
model1 <- function(a, data) {
  a[1] + data$x * a[2]
}

# residual = actual - predicted
# predicted need model parameter and data
measure_distance <- function(mod, data) {
  diff <- data$y - model1(mod, data)
  sqrt(mean(diff ^ 2))
}

measure_distance(c(7, 1.5), sim1)
## [1] 2.665212
  3. Now we can use purrr to compute the distance for all 250 models defined above.

We need a helper function because our distance function expects the model as a numeric vector of length 2.

model1 <- function(a, data) {
  a[1] + data$x * a[2]
}

measure_distance <- function(mod, data) {
  diff <- data$y - model1(mod, data)
  sqrt(mean(diff ^ 2))
}

# helper function
sim1_dist <- function(a1, a2) {
  measure_distance(c(a1, a2), sim1)
}

models <- models %>% 
  mutate(dist = purrr::map2_dbl(a1, a2, sim1_dist))

models
## # A tibble: 250 x 3
##        a1    a2  dist
##     <dbl> <dbl> <dbl>
##  1   9.41 -3.94 32.7 
##  2  30.4   1.19 21.7 
##  3 -14.3  -4.53 58.0 
##  4 -14.9  -2.89 48.5 
##  5 -18.5  -1.35 42.6 
##  6  24.1   4.45 33.8 
##  7  -4.43  3.06  4.75
##  8 -11.4  -1.62 37.4 
##  9  -4.14  2.08  8.48
## 10  33.0   1.85 27.7 
## # … with 240 more rows
  4. Next, let’s overlay the 10 best models on to the data.

I’ve coloured the models by -dist: this is an easy way to make sure that the best models (i.e. the ones with the smallest distance) get the brightest colours.

ggplot(sim1, aes(x, y))+
  geom_point(size = 2, colour = "grey30")+
  geom_abline(
    aes(intercept = a1, slope = a2, colour = -dist),
    data = filter(models, rank(dist) <= 10)
  )

Models as observations

We can also think about these models as observations, visualising them with a scatterplot of a1 vs a2, again coloured by -dist.

We can no longer directly see how the model compares to the data, but we can see many models at once.

Again, I’ve highlighted the 10 best models, this time by drawing red circles underneath them.

ggplot(models, aes(a1, a2))+
  geom_point(data = filter(models, rank(dist) <= 10),
             size = 4,
             colour = "red") +
  geom_point(aes(colour = -dist))

Instead of trying lots of random models, we could be more systematic and generate an evenly spaced grid of points (this is called a grid search). I picked the parameters of the grid roughly by looking at where the best models were in the plot above.

grid <- expand.grid(
  a1 = seq(-5, 20, length = 25),
  a2 = seq(1, 3, length = 25)) %>% 
  mutate(dist = purrr::map2_dbl(a1, a2, sim1_dist))

grid
##             a1       a2      dist
## 1   -5.0000000 1.000000 15.452475
## 2   -3.9583333 1.000000 14.443171
## 3   -2.9166667 1.000000 13.438807
## 4   -1.8750000 1.000000 12.440580
## 5   -0.8333333 1.000000 11.450094
## 6    0.2083333 1.000000 10.469547
## 7    1.2500000 1.000000  9.502017
## 8    2.2916667 1.000000  8.551921
## 9    3.3333333 1.000000  7.625781
## 10   4.3750000 1.000000  6.733488
## (output truncated: 615 of the 625 rows omitted)
grid %>% 
  ggplot(aes(a1, a2))+
  geom_point(data = filter(grid, rank(dist) <= 10),
             size = 4,
             colour = "red") + 
  geom_point(aes(colour = -dist))

When you overlay the best 10 models back on the original data, they all look pretty good:

ggplot(sim1, aes(x, y))+
  geom_point(size = 2, colour = "grey30") +
  geom_abline(
    aes(intercept = a1, slope = a2, colour = -dist),
    data = filter(grid, rank(dist) <= 10)
  )

Find the best model

You could imagine iteratively making the grid finer and finer until you narrowed in on the best model. But there’s a better way to tackle that problem: a numerical minimisation tool called Newton-Raphson search.

The intuition of Newton-Raphson is pretty simple: you pick a starting point and look around for the steepest slope. You then ski down that slope a little way, and repeat again and again, until you can’t go any lower. (This is like gradient descent: each step reduces the distance.)

In R, we can do that with optim():

model1 <- function(a, data) {
  a[1] + data$x * a[2]
}

measure_distance <- function(mod, data) {
  diff <- data$y - model1(mod, data)
  sqrt(mean(diff ^ 2))
}

# find the parameters that minimise the return value of measure_distance
best <- optim(c(0, 0), measure_distance, data = sim1)
best
## $par
## [1] 4.222248 2.051204
## 
## $value
## [1] 2.128181
## 
## $counts
## function gradient 
##       77       NA 
## 
## $convergence
## [1] 0
## 
## $message
## NULL
best$par
## [1] 4.222248 2.051204
ggplot(sim1, aes(x, y))+
  geom_point(size = 2, colour = "grey30") +
  geom_abline(intercept = best$par[1], slope = best$par[2])

Don’t worry too much about the details of how optim() works. It’s the intuition that’s important here:

  1. define a function that measures the distance between a model and a dataset,
  2. use an algorithm that can minimise that distance by modifying the model’s parameters, and
  3. you find the best model.

The neat thing about this approach is that it will work for any family of models that you can write an equation for.
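As a sketch of that generality, here is the same optim() recipe with a different distance measure, the mean absolute deviation, fitted to simulated straight-line data (the data, seed, and true parameters 4 and 2 are assumptions for illustration, not from sim1):

```r
# Same recipe, different distance: mean absolute deviation instead of
# root-mean-squared deviation, on simulated straight-line data.
set.seed(42)
sim <- data.frame(x = rep(1:10, 3))
sim$y <- 4 + 2 * sim$x + rnorm(30)

model1 <- function(a, data) a[1] + data$x * a[2]

# overall distance: average absolute residual
measure_abs <- function(mod, data) {
  mean(abs(data$y - model1(mod, data)))
}

best_abs <- optim(c(0, 0), measure_abs, data = sim)
best_abs$par  # should land near the true values (4, 2)
```

Nothing else in the recipe had to change: only the distance function is different, and optim() minimises it the same way.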

Alternative way

There’s one more approach that we can use for this model, because it’s a special case of a broader family: linear models. A linear model has the general form y = a_1 + a_2 * x_1 + a_3 * x_2 + … + a_n * x_(n - 1). So this simple model is equivalent to a general linear model where n is 2 and x_1 is x.

R has a tool specifically designed for fitting linear models called lm().

lm() has a special way to specify the model family: formulas.

Formulas look like y ~ x, which lm() will translate to a function like y = a_1 + a_2 * x. We can fit the model and look at the output:

sim1_mod <- lm(y ~ x, data = sim1)
coef(sim1_mod)
## (Intercept)           x 
##    4.220822    2.051533

These are exactly the same values we got with optim()!

Behind the scenes lm() doesn’t use optim() but instead takes advantage of the mathematical structure of linear models. Using some connections between geometry, calculus, and linear algebra, lm() actually finds the closest model in a single step, using a sophisticated algorithm. This approach is both faster, and guarantees that there is a global minimum.
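To make that one-step idea concrete, here is a minimal base-R sketch: least squares has a closed-form solution, and qr.solve() computes it directly from the design matrix. The simulated data stands in for sim1:

```r
# Least squares projects y onto the column space of the design matrix;
# qr.solve() finds the least-squares solution in a single step.
set.seed(1)
df <- data.frame(x = runif(50, 0, 10))
df$y <- 4 + 2 * df$x + rnorm(50)

X <- cbind(1, df$x)        # design matrix: a column of ones, then x
beta <- qr.solve(X, df$y)  # one-step least-squares solution

# The same answer lm() gives:
all.equal(unname(coef(lm(y ~ x, data = df))), unname(beta))
```

No iteration, no starting point, no risk of stopping in a local minimum.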

Visualising models

For simple models, like the one above, you can figure out what pattern the model captures by carefully studying the model family and the fitted coefficients. And if you ever take a statistics course on modelling, you’re likely to spend a lot of time doing just that. Here, however, we’re going to take a different tack.

Predictions

To visualise the predictions from a model:

  1. Generate an evenly spaced grid of values that covers the region where our data lies.
    • Use modelr::data_grid().
    • Its first argument is a data frame.
    • For each subsequent argument, it finds the unique values of that variable.
    • It then generates all combinations:
grid <- sim1 %>% 
  data_grid(x)
grid
## # A tibble: 10 x 1
##        x
##    <int>
##  1     1
##  2     2
##  3     3
##  4     4
##  5     5
##  6     6
##  7     7
##  8     8
##  9     9
## 10    10
  2. Next, add predictions.
    • Use modelr::add_predictions().
    • It takes a data frame and a model.
    • It adds the predictions from the model as a new column in the data frame:
sim1_mod <- lm(y ~ x, data = sim1)
coef(sim1_mod)
## (Intercept)           x 
##    4.220822    2.051533
grid <- grid %>% 
  add_predictions(sim1_mod)
grid
## # A tibble: 10 x 2
##        x  pred
##    <int> <dbl>
##  1     1  6.27
##  2     2  8.32
##  3     3 10.4 
##  4     4 12.4 
##  5     5 14.5 
##  6     6 16.5 
##  7     7 18.6 
##  8     8 20.6 
##  9     9 22.7 
## 10    10 24.7
  3. Finally, plot the predictions.
ggplot(sim1, aes(x)) +
  geom_point(aes(y = y)) +
  geom_line(aes(y = pred), data = grid, colour = "red", size = 1)

You might wonder about all this extra work compared to just using geom_abline(). But the advantage of this approach is that it will work with any model in R, from the simplest to the most complex. You’re only limited by your visualisation skills.

Residuals

The flip-side of predictions is residuals.

  • The predictions tell you the pattern that the model has captured.
  • The residuals tell you what the model has missed.

The residuals are just the distances between the observed and predicted values that we computed above.

\[residual = observed - predicted\]

We add residuals to the data with add_residuals(), which works much like add_predictions().

Note, however, that we use the original dataset, not a manufactured grid. This is because to compute residuals we need actual y values.

sim1 <- sim1 %>% 
  add_residuals(sim1_mod)
sim1
## # A tibble: 30 x 3
##        x     y    resid
##    <int> <dbl>    <dbl>
##  1     1  4.20 -2.07   
##  2     1  7.51  1.24   
##  3     1  2.13 -4.15   
##  4     2  8.99  0.665  
##  5     2 10.2   1.92   
##  6     2 11.3   2.97   
##  7     3  7.36 -3.02   
##  8     3 10.5   0.130  
##  9     3 10.5   0.136  
## 10     4 12.4   0.00763
## # … with 20 more rows

There are a few different ways to understand what the residuals tell us about the model.

  1. One way is to simply draw a frequency polygon to help us understand the spread of the residuals:
ggplot(sim1, aes(resid)) +
  geom_freqpoly(binwidth = 0.5)

This helps you calibrate the quality of the model: how far away are the predictions from the observed values? Note that the average of the residuals will always be 0.
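That last fact is easy to check: any least-squares fit that includes an intercept has residuals that sum (and hence average) to zero, up to floating-point error. A self-contained sketch on simulated data:

```r
# For a least-squares fit with an intercept, the residuals average to
# exactly zero (up to floating-point error), by construction.
set.seed(1)
df <- data.frame(x = 1:20)
df$y <- 3 + 1.5 * df$x + rnorm(20)

mod <- lm(y ~ x, data = df)
mean(resid(mod))  # effectively 0
```

So the frequency polygon of residuals is always centred at zero; what matters is its spread and shape.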

  2. Another way is to recreate plots using the residuals instead of the original predictor; you’ll see a lot of that in the next chapter.

ggplot(sim1, aes(x, resid)) +
  geom_ref_line(h = 0) +
  geom_point()

This looks like random noise, suggesting that our model has done a good job of capturing the patterns in the dataset.

Formulas and model families

In R, formulas provide a general way of getting “special behaviour”. Rather than evaluating the values of the variables right away, they capture them so they can be interpreted by the function.

The majority of modelling functions in R use a standard conversion from formulas to functions.

You’ve seen one simple conversion already: y ~ x is translated to y = a_1 + a_2 * x. If you want to see what R actually does, you can use the model_matrix() function:

df <- tribble(
  ~y, ~x1, ~x2,
  4, 2, 5,
  5, 1, 6
)
df
## # A tibble: 2 x 3
##       y    x1    x2
##   <dbl> <dbl> <dbl>
## 1     4     2     5
## 2     5     1     6
# the design matrix R builds from the formula
model_matrix(df, y ~ x1)
## # A tibble: 2 x 2
##   `(Intercept)`    x1
##           <dbl> <dbl>
## 1             1     2
## 2             1     1
# the coefficients lm() fits to that matrix
lm(y ~ x1, data = df)
## 
## Call:
## lm(formula = y ~ x1, data = df)
## 
## Coefficients:
## (Intercept)           x1  
##           6           -1

R adds the intercept to the model simply by including a column full of ones. By default, R will always add this column. If you don’t want the intercept, you need to explicitly drop it with -1:

model_matrix(df, y ~ x1 - 1)
## # A tibble: 2 x 1
##      x1
##   <dbl>
## 1     2
## 2     1

The model matrix grows in an unsurprising way when you add more variables to the model:

df
## # A tibble: 2 x 3
##       y    x1    x2
##   <dbl> <dbl> <dbl>
## 1     4     2     5
## 2     5     1     6
model_matrix(df, y ~ x1 + x2)
## # A tibble: 2 x 3
##   `(Intercept)`    x1    x2
##           <dbl> <dbl> <dbl>
## 1             1     2     5
## 2             1     1     6

This formula notation is sometimes called “Wilkinson-Rogers notation”, and was initially described in Symbolic Description of Factorial Models for Analysis of Variance, by G. N. Wilkinson and C. E. Rogers https://www.jstor.org/stable/2346786. It’s worth digging up and reading the original paper if you’d like to understand the full details of the modelling algebra.

The following sections expand on how this formula notation works for categorical variables, interactions, and transformations.

Categorical variables

Imagine you have a formula like y ~ sex, where sex could either be male or female.

It doesn’t make sense to convert that to an equation like y = a_0 + a_1 * sex, because sex isn’t a number - you can’t multiply it!

Instead, what R does is convert it to y = a_0 + a_1 * sexmale, where sexmale is 1 if sex is male and 0 otherwise:

df <- tribble(
  ~ sex, ~ response,
  "male", 1,
  "female", 2,
  "male", 1
)
df
## # A tibble: 3 x 2
##   sex    response
##   <chr>     <dbl>
## 1 male          1
## 2 female        2
## 3 male          1
model_matrix(df, response ~ sex)
## # A tibble: 3 x 2
##   `(Intercept)` sexmale
##           <dbl>   <dbl>
## 1             1       1
## 2             1       0
## 3             1       1

You might wonder why R doesn’t also create a sexfemale column. The problem is that it would create a column that is perfectly predictable from the other columns (i.e. sexfemale = 1 - sexmale). Unfortunately the exact details of why this is a problem are beyond the scope of this book, but basically it creates a model family that is too flexible, and will have infinitely many models that are equally close to the data.
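You can see the redundancy directly: if both dummy columns were included, the design matrix would have fewer independent columns than parameters. A minimal sketch with a tiny hand-built matrix:

```r
# If R created both sexmale and sexfemale, the columns would be linearly
# dependent: the intercept column equals sexmale + sexfemale.
sexmale   <- c(1, 0, 1)
sexfemale <- 1 - sexmale
X <- cbind(intercept = 1, sexmale, sexfemale)

qr(X)$rank  # 2, not 3: one of the three parameters is redundant
```

With a rank-deficient design matrix there is no unique set of coefficients, which is exactly the "infinitely many equally close models" problem described above.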

Fortunately, however, if you focus on visualising predictions you don’t need to worry about the exact parameterisation. Let’s look at some data and models to make that concrete. Here’s the sim2 dataset from modelr:

sim2
## # A tibble: 40 x 2
##    x          y
##    <chr>  <dbl>
##  1 a      1.94 
##  2 a      1.18 
##  3 a      1.24 
##  4 a      2.62 
##  5 a      1.11 
##  6 a      0.866
##  7 a     -0.910
##  8 a      0.721
##  9 a      0.687
## 10 a      2.07 
## # … with 30 more rows
ggplot(sim2) +
  geom_point(aes(x, y))

We can fit a model to it, and generate predictions:

mod2 <- lm(y ~ x, data = sim2)

grid <- sim2 %>% 
  data_grid(x) %>% 
  add_predictions(mod2)

grid
## # A tibble: 4 x 2
##   x      pred
##   <chr> <dbl>
## 1 a      1.15
## 2 b      8.12
## 3 c      6.13
## 4 d      1.91

Effectively, a model with a categorical x will predict the mean value for each category. (Why? Because the mean minimises the root-mean-squared distance.) That’s easy to see if we overlay the predictions on top of the original data:

ggplot(sim2, aes(x)) +
  geom_point(aes(y = y)) +
  geom_point(data = grid, aes(y = pred), colour = "red", size = 4)
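That claim is easy to verify numerically: the predictions for each level equal the group means. A self-contained sketch using base R's predict() and tapply() on simulated data (standing in for sim2):

```r
# With a single categorical predictor, least squares predicts the
# mean of y within each level.
set.seed(1)
df <- data.frame(x = rep(c("a", "b", "c"), each = 10))
df$y <- rnorm(30, mean = rep(c(1, 8, 6), each = 10))

mod <- lm(y ~ x, data = df)
grid <- data.frame(x = c("a", "b", "c"))
preds <- predict(mod, grid)

group_means <- tapply(df$y, df$x, mean)
all.equal(unname(preds), as.vector(group_means))
```

Whatever the fitted coefficients look like, the predictions are just the per-level averages.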

You can’t make predictions about levels that you didn’t observe. Sometimes you’ll do this by accident, so it’s good to learn to recognise the error message that results.

Interactions (continuous and categorical)

What happens when you combine a continuous and a categorical variable?

sim3 contains a categorical predictor and a continuous predictor.

We can visualise it with a simple plot:

sim3
## # A tibble: 120 x 5
##       x1 x2      rep      y    sd
##    <int> <fct> <int>  <dbl> <dbl>
##  1     1 a         1 -0.571     2
##  2     1 a         2  1.18      2
##  3     1 a         3  2.24      2
##  4     1 b         1  7.44      2
##  5     1 b         2  8.52      2
##  6     1 b         3  7.72      2
##  7     1 c         1  6.51      2
##  8     1 c         2  5.79      2
##  9     1 c         3  6.07      2
## 10     1 d         1  2.11      2
## # … with 110 more rows
ggplot(sim3, aes(x1, y)) +
  geom_point(aes(colour = x2))

There are two possible models you could fit to this data:

mod1 <- lm(y ~ x1 + x2, data = sim3)
mod1
## 
## Call:
## lm(formula = y ~ x1 + x2, data = sim3)
## 
## Coefficients:
## (Intercept)           x1          x2b          x2c          x2d  
##      1.8717      -0.1967       2.8878       4.8057       2.3596
model_matrix(sim3, y ~ x1 + x2)
## # A tibble: 120 x 5
##    `(Intercept)`    x1   x2b   x2c   x2d
##            <dbl> <dbl> <dbl> <dbl> <dbl>
##  1             1     1     0     0     0
##  2             1     1     0     0     0
##  3             1     1     0     0     0
##  4             1     1     1     0     0
##  5             1     1     1     0     0
##  6             1     1     1     0     0
##  7             1     1     0     1     0
##  8             1     1     0     1     0
##  9             1     1     0     1     0
## 10             1     1     0     0     1
## # … with 110 more rows
mod2 <- lm(y ~ x1 * x2, data = sim3)
coef(mod2)
## (Intercept)          x1         x2b         x2c         x2d      x1:x2b 
##  1.30124266 -0.09302444  7.06937991  4.43089525  0.83455115 -0.76028528 
##      x1:x2c      x1:x2d 
##  0.06815284  0.27727920
model_matrix(sim3, y ~ x1 * x2)
## # A tibble: 120 x 8
##    `(Intercept)`    x1   x2b   x2c   x2d `x1:x2b` `x1:x2c` `x1:x2d`
##            <dbl> <dbl> <dbl> <dbl> <dbl>    <dbl>    <dbl>    <dbl>
##  1             1     1     0     0     0        0        0        0
##  2             1     1     0     0     0        0        0        0
##  3             1     1     0     0     0        0        0        0
##  4             1     1     1     0     0        1        0        0
##  5             1     1     1     0     0        1        0        0
##  6             1     1     1     0     0        1        0        0
##  7             1     1     0     1     0        0        1        0
##  8             1     1     0     1     0        0        1        0
##  9             1     1     0     1     0        0        1        0
## 10             1     1     0     0     1        0        0        1
## # … with 110 more rows
  • +: the model will estimate each effect independently of all the others.

  • *: the model will also fit the so-called interaction, i.e. the effect of one variable is allowed to depend on the value of the other.

For example, y ~ x1 * x2 is translated to y = a_0 + a_1 * x1 + a_2 * x2 + a_12 * x1 * x2. Note that whenever you use *, both the interaction and the individual components are included in the model.
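To convince yourself of that expansion, you can spell the interaction out by hand with I() and check that the fit is identical (only the coefficient names differ). A minimal sketch on simulated data:

```r
# y ~ x1 * x2 is shorthand for the main effects plus their product.
set.seed(1)
df <- data.frame(x1 = runif(40), x2 = runif(40))
df$y <- 1 + 2 * df$x1 - 3 * df$x2 + 4 * df$x1 * df$x2 + rnorm(40, sd = 0.5)

mod_star <- lm(y ~ x1 * x2, data = df)
mod_hand <- lm(y ~ x1 + x2 + I(x1 * x2), data = df)

# Same coefficients; the interaction term is named x1:x2 vs I(x1 * x2).
all.equal(unname(coef(mod_star)), unname(coef(mod_hand)))
```

Both formulas produce the same design matrix, so lm() fits exactly the same model.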

To visualise these models we need two new tricks:

  1. We have two predictors, so we need to give data_grid() both variables. It finds all the unique values of x1 and x2 and then generates all combinations.
grid <- sim3 %>% 
  data_grid(x1, x2)
grid
## # A tibble: 40 x 2
##       x1 x2   
##    <int> <fct>
##  1     1 a    
##  2     1 b    
##  3     1 c    
##  4     1 d    
##  5     2 a    
##  6     2 b    
##  7     2 c    
##  8     2 d    
##  9     3 a    
## 10     3 b    
## # … with 30 more rows
  2. To generate predictions from both models simultaneously, we can use gather_predictions(), which adds each prediction as a row. The complement of gather_predictions() is spread_predictions(), which adds each prediction to a new column.
grid <- sim3 %>% 
  data_grid(x1, x2) %>% 
  gather_predictions(mod1, mod2)
grid
## # A tibble: 80 x 4
##    model    x1 x2     pred
##    <chr> <int> <fct> <dbl>
##  1 mod1      1 a      1.67
##  2 mod1      1 b      4.56
##  3 mod1      1 c      6.48
##  4 mod1      1 d      4.03
##  5 mod1      2 a      1.48
##  6 mod1      2 b      4.37
##  7 mod1      2 c      6.28
##  8 mod1      2 d      3.84
##  9 mod1      3 a      1.28
## 10 mod1      3 b      4.17
## # … with 70 more rows

We can visualise the results for both models on one plot using facetting:

ggplot(sim3, aes(x1, y, colour = x2)) +
  geom_point() +
  geom_line(data = grid, aes(y = pred)) +
  facet_wrap(~ model)

Note that the model that uses + has the same slope for each line, but different intercepts. The model that uses * has a different slope and intercept for each line.

Which model is better for this data? We can take a look at the residuals.

Here I’ve facetted by both model and x2 because it makes it easier to see the pattern within each group.

sim3
## # A tibble: 120 x 5
##       x1 x2      rep      y    sd
##    <int> <fct> <int>  <dbl> <dbl>
##  1     1 a         1 -0.571     2
##  2     1 a         2  1.18      2
##  3     1 a         3  2.24      2
##  4     1 b         1  7.44      2
##  5     1 b         2  8.52      2
##  6     1 b         3  7.72      2
##  7     1 c         1  6.51      2
##  8     1 c         2  5.79      2
##  9     1 c         3  6.07      2
## 10     1 d         1  2.11      2
## # … with 110 more rows
sim3_resid <- sim3 %>% 
  gather_residuals(mod1, mod2)

sim3_resid
## # A tibble: 240 x 7
##    model    x1 x2      rep      y    sd   resid
##    <chr> <int> <fct> <int>  <dbl> <dbl>   <dbl>
##  1 mod1      1 a         1 -0.571     2 -2.25  
##  2 mod1      1 a         2  1.18      2 -0.491 
##  3 mod1      1 a         3  2.24      2  0.562 
##  4 mod1      1 b         1  7.44      2  2.87  
##  5 mod1      1 b         2  8.52      2  3.96  
##  6 mod1      1 b         3  7.72      2  3.16  
##  7 mod1      1 c         1  6.51      2  0.0261
##  8 mod1      1 c         2  5.79      2 -0.691 
##  9 mod1      1 c         3  6.07      2 -0.408 
## 10 mod1      1 d         1  2.11      2 -1.92  
## # … with 230 more rows
ggplot(sim3_resid, aes(x1, resid, colour = x2)) +
  geom_point() +
  facet_grid(model ~ x2)

There is little obvious pattern in the residuals for mod2. The residuals for mod1 show that the model has clearly missed some pattern in b and, to a lesser extent, in c and d.

You might wonder if there’s a precise way to tell which of mod1 or mod2 is better. There is, but it requires a lot of mathematical background, and we don’t really care. Here, we’re interested in a qualitative assessment of whether or not the model has captured the pattern that we’re interested in.

Interactions (two continuous)

Let’s take a look at the equivalent model for two continuous variables. Initially things proceed almost identically to the previous example:

sim4
## # A tibble: 300 x 4
##       x1     x2   rep       y
##    <dbl>  <dbl> <int>   <dbl>
##  1    -1 -1         1  4.25  
##  2    -1 -1         2  1.21  
##  3    -1 -1         3  0.353 
##  4    -1 -0.778     1 -0.0467
##  5    -1 -0.778     2  4.64  
##  6    -1 -0.778     3  1.38  
##  7    -1 -0.556     1  0.975 
##  8    -1 -0.556     2  2.50  
##  9    -1 -0.556     3  2.70  
## 10    -1 -0.333     1  0.558 
## # … with 290 more rows
mod1 <- lm(y ~ x1 + x2, data = sim4)
mod1
## 
## Call:
## lm(formula = y ~ x1 + x2, data = sim4)
## 
## Coefficients:
## (Intercept)           x1           x2  
##     0.03546      1.82167     -2.78252
mod2 <- lm(y ~ x1 * x2, data = sim4)
mod2
## 
## Call:
## lm(formula = y ~ x1 * x2, data = sim4)
## 
## Coefficients:
## (Intercept)           x1           x2        x1:x2  
##     0.03546      1.82167     -2.78252      0.95228
grid <- sim4 %>% 
  data_grid(
    x1 = seq_range(x1, 5),
    x2 = seq_range(x2, 5)
    ) %>% 
  gather_predictions(mod1, mod2)
grid  
## # A tibble: 50 x 4
##    model    x1    x2   pred
##    <chr> <dbl> <dbl>  <dbl>
##  1 mod1   -1    -1    0.996
##  2 mod1   -1    -0.5 -0.395
##  3 mod1   -1     0   -1.79 
##  4 mod1   -1     0.5 -3.18 
##  5 mod1   -1     1   -4.57 
##  6 mod1   -0.5  -1    1.91 
##  7 mod1   -0.5  -0.5  0.516
##  8 mod1   -0.5   0   -0.875
##  9 mod1   -0.5   0.5 -2.27 
## 10 mod1   -0.5   1   -3.66 
## # … with 40 more rows

Note my use of seq_range() inside data_grid(). Instead of using every unique value of x, I’m going to use a regularly spaced grid of five values between the minimum and maximum values.

There are a few other useful arguments to seq_range():

  • pretty = TRUE will generate a “pretty” sequence, i.e. something that looks nice to the human eye. This is useful if you want to produce tables of output:
seq_range(c(0.0123, 0.923423), n = 5)
## [1] 0.0123000 0.2400808 0.4678615 0.6956423 0.9234230
seq_range(c(0.0123, 0.923423), n = 5, pretty = TRUE)
## [1] 0.0 0.2 0.4 0.6 0.8 1.0
  • trim = 0.1 will trim off 10% of the tail values. This is useful if the variables have a long tailed distribution and you want to focus on generating values near the center:
x1 <- rcauchy(100)
x1 <- tibble(x1)
ggplot(x1, aes(x1)) +
  geom_freqpoly()
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.

x1 <- x1$x1
seq_range(x1, n = 5)
## [1] -75.466451 -40.407920  -5.349389  29.709142  64.767673
seq_range(x1, n = 5, trim = 0.1)
## [1] -4.9474863 -0.9232129  3.1010606  7.1253340 11.1496074
seq_range(x1, n = 5, trim = 0.25)
## [1] -2.4376807 -1.0551045  0.3274718  1.7100481  3.0926243
seq_range(x1, n = 5, trim = 0.50)
## [1] -1.5680331 -0.9106162 -0.2531994  0.4042175  1.0616344
  • expand = 0.1 is, in some sense, the opposite of trim: it expands the range by 10%.
x2 <- c(0, 1)
seq_range(x2, n = 5)
## [1] 0.00 0.25 0.50 0.75 1.00
seq_range(x2, n = 5, expand = 0.1)
## [1] -0.050  0.225  0.500  0.775  1.050
seq_range(x2, n = 5, expand = 0.25)
## [1] -0.1250  0.1875  0.5000  0.8125  1.1250
seq_range(x2, n = 5, expand = 0.5)
## [1] -0.250  0.125  0.500  0.875  1.250

Next let’s try and visualise that model. We have two continuous predictors, so you can imagine the model like a 3d surface. We could display that using geom_tile():

ggplot(grid, aes(x1, x2)) +
  geom_tile(aes(fill = pred)) +
  facet_wrap(~ model)

That doesn’t suggest that the models are very different! But that’s partly an illusion: our eyes and brains are not very good at accurately comparing shades of colour.

Instead of looking at the surface from the top, we could look at it from either side, showing multiple slices:

ggplot(grid, aes(x1, pred, colour = x2, group = x2)) +
  geom_line() +
  facet_wrap(~ model)

ggplot(grid, aes(x2, pred, colour = x1, group = x1)) +
  geom_line() +
  facet_wrap(~ model)

This shows you that interaction between two continuous variables works basically the same way as for a categorical and continuous variable. An interaction says that there’s not a fixed offset: you need to consider both values of x1 and x2 simultaneously in order to predict y.

You can see that even with just two continuous variables, coming up with good visualisations is hard. But that’s reasonable: you shouldn’t expect it to be easy to understand how three or more variables simultaneously interact! But again, we’re saved a little because we’re using models for exploration, and you can gradually build up your model over time. The model doesn’t have to be perfect, it just has to help you reveal a little more about your data.

Transformations

You can also perform transformations inside the model formula.

If your transformation involves +, *, ^, or -, you’ll need to wrap it in I() so R doesn’t treat it like part of the model specification.

  • y ~ I(x ^ 2) + x —> y = a_1 + a_2 * x + a_3 * x^2
  • y ~ x ^ 2 + x —> y ~ x * x + x —> y = a_1 + a_2 * x

Again, if you get confused about what your model is doing, you can always use model_matrix() to see exactly what equation lm() is fitting:

df <- tribble(
  ~y, ~x,
  1, 1,
  2, 2,
  3, 3
)

df
## # A tibble: 3 x 2
##       y     x
##   <dbl> <dbl>
## 1     1     1
## 2     2     2
## 3     3     3
model_matrix(df, y ~ x^2 + x)
## # A tibble: 3 x 2
##   `(Intercept)`     x
##           <dbl> <dbl>
## 1             1     1
## 2             1     2
## 3             1     3
model_matrix(df, y ~ I(x^2) + x)
## # A tibble: 3 x 3
##   `(Intercept)` `I(x^2)`     x
##           <dbl>    <dbl> <dbl>
## 1             1        1     1
## 2             1        4     2
## 3             1        9     3

Transformations are useful because you can use them to approximate non-linear functions.

If you’ve taken a calculus class, you may have heard of Taylor’s theorem which says you can approximate any smooth function with an infinite sum of polynomials.

That means you can use a polynomial function to get arbitrarily close to a smooth function by fitting an equation like y = a_1 + a_2 * x + a_3 * x^2 + a_4 * x ^ 3.

Typing that sequence by hand is tedious, so R provides a helper function: poly():

model_matrix(df, y ~ poly(x, 2))
## # A tibble: 3 x 3
##   `(Intercept)` `poly(x, 2)1` `poly(x, 2)2`
##           <dbl>         <dbl>         <dbl>
## 1             1     -7.07e- 1         0.408
## 2             1     -7.85e-17        -0.816
## 3             1      7.07e- 1         0.408

However there’s one major problem with using poly(): outside the range of the data, polynomials rapidly shoot off to positive or negative infinity. One safer alternative is to use the natural spline, splines::ns().

library(splines)
model_matrix(df, y ~ ns(x, 2))
## # A tibble: 3 x 3
##   `(Intercept)` `ns(x, 2)1` `ns(x, 2)2`
##           <dbl>       <dbl>       <dbl>
## 1             1       0           0    
## 2             1       0.566      -0.211
## 3             1       0.344       0.771

Let’s see what that looks like when we try and approximate a non-linear function:

sim5 <- tibble(
  x = seq(0, 3.5 * pi, length = 50),
  y = 4 * sin(x) + rnorm(length(x))
)
sim5
## # A tibble: 50 x 2
##        x      y
##    <dbl>  <dbl>
##  1 0     -1.06 
##  2 0.224  0.342
##  3 0.449  1.56 
##  4 0.673  1.38 
##  5 0.898  4.14 
##  6 1.12   3.59 
##  7 1.35   3.06 
##  8 1.57   4.12 
##  9 1.80   5.39 
## 10 2.02   3.34 
## # … with 40 more rows
ggplot(sim5, aes(x, y)) +
  geom_point()

I’m going to fit five models to this data.

mod1 <- lm(y ~ ns(x, 1), data = sim5)
mod2 <- lm(y ~ ns(x, 2), data = sim5)
mod3 <- lm(y ~ ns(x, 3), data = sim5)
mod4 <- lm(y ~ ns(x, 4), data = sim5)
mod5 <- lm(y ~ ns(x, 5), data = sim5)
grid <- sim5 %>% 
  data_grid(x = seq_range(x, n = 50, expand = 0.1)) %>% 
  gather_predictions(mod1, mod2, mod3, mod4, mod5, .pred = "y")

grid
## # A tibble: 250 x 3
##    model       x     y
##    <chr>   <dbl> <dbl>
##  1 mod1  -0.550   1.57
##  2 mod1  -0.303   1.52
##  3 mod1  -0.0561  1.46
##  4 mod1   0.191   1.41
##  5 mod1   0.438   1.36
##  6 mod1   0.684   1.31
##  7 mod1   0.931   1.26
##  8 mod1   1.18    1.21
##  9 mod1   1.42    1.16
## 10 mod1   1.67    1.11
## # … with 240 more rows
ggplot(sim5, aes(x, y)) +
  geom_point() +
  geom_line(data = grid, colour = "red") +
  facet_wrap(~ model)

Notice that the extrapolation outside the range of the data is clearly bad. This is the downside to approximating a function with a polynomial.

But this is a very real problem with every model: the model can never tell you if the behaviour is true when you start extrapolating outside the range of the data that you have seen. You must rely on theory and science.

Missing values

Missing values obviously can’t convey any information about the relationship between the variables, so modelling functions will drop any rows that contain missing values. R’s default behaviour is to silently drop them, but options(na.action = na.warn) (run in the prerequisites) makes sure you get a warning.

df <- tribble(
  ~x, ~y,
  1, 2.2,
  2, NA,
  3, 3.5,
  4, 8.3,
  NA, 10
)
df
## # A tibble: 5 x 2
##       x     y
##   <dbl> <dbl>
## 1     1   2.2
## 2     2  NA  
## 3     3   3.5
## 4     4   8.3
## 5    NA  10
mod <- lm(y ~ x, data = df)
## Warning: Dropping 2 rows with missing values
mod
## 
## Call:
## lm(formula = y ~ x, data = df)
## 
## Coefficients:
## (Intercept)            x  
##     -0.2286       1.8357

To suppress the warning, set na.action = na.exclude:

mod <- lm(y ~ x, data = df, na.action = na.exclude)
mod
## 
## Call:
## lm(formula = y ~ x, data = df, na.action = na.exclude)
## 
## Coefficients:
## (Intercept)            x  
##     -0.2286       1.8357

You can always see exactly how many observations were used with nobs():

nobs(mod)
## [1] 3

Other model families

This chapter has focussed exclusively on the class of linear models, which assume a relationship of the form y = a_1 * x1 + a_2 * x2 + … + a_n * xn.

Linear models additionally assume that the residuals have a normal distribution, which we haven’t talked about.

There is a large set of model classes that extend the linear model in various interesting ways. Some of them are:

  • Generalised linear models, e.g. stats::glm(), which extend linear models to non-continuous responses (e.g. binary data or counts).
  • Generalised additive models, e.g. mgcv::gam(), which extend generalised linear models to incorporate arbitrary smooth functions.
  • Penalised linear models, e.g. glmnet::glmnet(), which add a penalty term that "shrinks" the model towards simpler parameter values.
  • Robust linear models, e.g. MASS::rlm(), which downweight points that are far from the fitted line, making them less sensitive to outliers.
  • Trees, e.g. rpart::rpart(), which fit a piece-wise constant model by splitting the data into progressively smaller pieces.

These models all work similarly from a programming perspective. Once you’ve mastered linear models, you should find it easy to master the mechanics of these other model classes.

Being a skilled modeller is a mixture of some good general principles and having a big toolbox of techniques.